#!/usr/bin/env python2
from __future__ import print_function
import sys
import os
import os.path
import re

from argparse import ArgumentParser
from argparse import FileType
from pprint import pprint
import shlex
import subprocess
from subprocess import check_call
from subprocess import Popen
from subprocess import PIPE
from subprocess import CalledProcessError
from datetime import datetime
from time import sleep
from errno import ESRCH
try:
    import termcolor
except ImportError:
    termcolor = None
from random import random
import commands
import multiprocessing
from contextlib import closing


# Substrings of a test's stderr that make the runner re-run the test
# (see need_retry) instead of reporting the failure right away.
MESSAGES_TO_RETRY = [
    "DB::Exception: ZooKeeper session has been expired",
    "Coordination::Exception: Connection loss",
]


def remove_control_characters(s):
    """Decode numeric character references, then drop control characters.

    Decimal (``&#NNN;``) and hexadecimal (``&#xHHH;``) references whose code
    point is below 0x10000 are replaced by the character itself; references
    to larger code points are left as written.  After that every character in
    the ranges 0x00-0x08, 0x0b, 0x0e-0x1f and 0x7f is removed.

    Based on
    https://github.com/html5lib/html5lib-python/issues/96#issuecomment-43438438
    """
    def _decode(match, base):
        code_point = int(match.group(1), base)
        if code_point >= 0x10000:
            # Leave the reference text untouched.
            return match.group(0)
        return unichr(code_point)

    def _decode_decimal(match):
        return _decode(match, 10)

    def _decode_hexadecimal(match):
        return _decode(match, 16)

    text = re.sub(r"&#(\d+);?", _decode_decimal, s)
    text = re.sub(r"&#[xX]([0-9a-fA-F]+);?", _decode_hexadecimal, text)
    return re.sub(r"[\x00-\x08\x0b\x0e-\x1f\x7f]", "", text)


def run_single_test(args, ext, server_logs_level, client_options, case_file, stdout_file, stderr_file):
    """Run one test case through the shell and collect its output.

    For ``.sql`` cases the file is fed on stdin to ``args.client_with_database``
    (with ``--send_logs_level``, ``--testmode``, ``--multiquery`` and the extra
    ``client_options``); for any other extension ``case_file`` itself is used
    as the command.  Standard output and error are redirected to
    ``stdout_file`` and ``stderr_file``.

    The call waits for at most ``args.timeout`` seconds.  It does NOT kill the
    process on timeout: the caller detects that case by ``proc.returncode``
    being ``None``.

    Returns a tuple ``(proc, stdout, stderr, total_time)`` where ``stdout`` and
    ``stderr`` are unicode strings (undecodable bytes replaced) and
    ``total_time`` is the elapsed wall-clock time in seconds.
    """
    def read_output(path):
        # A missing file is treated as empty output.  The file is closed
        # explicitly instead of relying on garbage collection.
        raw = ''
        if os.path.exists(path):
            with open(path, 'r') as output:
                raw = output.read()
        return unicode(raw, errors='replace', encoding='utf-8')

    params = {
        'client': args.client_with_database,
        'logs_level': server_logs_level,
        'options': client_options,
        'test': case_file,
        'stdout': stdout_file,
        'stderr': stderr_file,
    }

    pattern = '{test} > {stdout} 2> {stderr}'

    if ext == '.sql':
        pattern = "{client} --send_logs_level={logs_level} --testmode --multiquery {options} < " + pattern

    command = pattern.format(**params)

    proc = Popen(command, shell=True, env=os.environ)
    start_time = datetime.now()
    # Poll instead of wait() so the timeout can be enforced.
    while (datetime.now() - start_time).total_seconds() < args.timeout and proc.poll() is None:
        sleep(0.01)

    total_time = (datetime.now() - start_time).total_seconds()

    # Normalize randomized database names in stdout, stderr files.
    os.system("LC_ALL=C sed -i -e 's/{test_db}/default/g' {file}".format(test_db=args.database, file=stdout_file))
    os.system("LC_ALL=C sed -i -e 's/{test_db}/default/g' {file}".format(test_db=args.database, file=stderr_file))

    stdout = read_output(stdout_file)
    stderr = read_output(stderr_file)

    return proc, stdout, stderr, total_time


def need_retry(stderr):
    """Tell whether ``stderr`` contains any of the MESSAGES_TO_RETRY substrings."""
    for message in MESSAGES_TO_RETRY:
        if message in stderr:
            return True
    return False


def get_processlist(client_cmd):
    """Return the output of ``SHOW PROCESSLIST FORMAT Vertical``.

    ``client_cmd`` is the client command line; the query is appended to it and
    the result is run through the shell.  If the command cannot be run or
    exits with a non-zero status, an empty string is returned.
    """
    query_cmd = "{} --query 'SHOW PROCESSLIST FORMAT Vertical'".format(client_cmd)
    try:
        return subprocess.check_output(query_cmd, shell=True)
    except Exception:
        # Best effort: any failure is reported as "nothing running".
        # KeyboardInterrupt/SystemExit are deliberately not swallowed.
        return ""  # server seems dead


def get_stacktraces_from_gdb(server_pid):
    """Collect server stacktraces using gdb.

    Runs gdb in batch mode against ``server_pid`` with
    ``thread apply all backtrace`` and returns what it printed to stdout.
    On any error a descriptive message is returned instead of raising.
    """
    gdb_command = "gdb -batch -ex 'thread apply all backtrace' -p {}".format(server_pid)
    try:
        backtraces = subprocess.check_output(gdb_command, shell=True)
    except Exception as ex:
        return "Error occured while receiving stack traces from gdb: {}".format(str(ex))
    return backtraces


def get_stacktraces_from_clickhouse(client):
    """Collect server stacktraces from the system.stack_trace table.

    ``client`` is the client command line; the introspection query is appended
    to it and run through the shell.  Returns the client's stdout, or a
    descriptive message if the command fails.
    """
    # Adjacent literals: together they form exactly one command template.
    command_template = (
        "{} --allow_introspection_functions=1 --query \""
        "SELECT arrayStringConcat("
        "arrayMap(x, y -> concat(x, ': ', y), "
        "arrayMap(x -> addressToLine(x), trace), "
        "arrayMap(x -> demangle(addressToSymbol(x)), trace)), "
        "'\n') as trace FROM system.stack_trace format Vertical\""
    )
    try:
        traces = subprocess.check_output(command_template.format(client), shell=True)
    except Exception as ex:
        return "Error occured while receiving stack traces from client: {}".format(str(ex))
    return traces


def get_server_pid(server_tcp_port):
    """Return the PID of the process listening on ``server_tcp_port``.

    The PID is taken from ``lsof`` output filtered by ``awk``.  ``None`` is
    returned when nothing is printed, when the output is not a single integer,
    or when the command fails.
    """
    lsof_pipeline = "lsof -i tcp:{port} -s tcp:LISTEN -Fp | awk '/^p[0-9]+$/{{print substr($0, 2)}}'".format(port=server_tcp_port)
    try:
        pid_text = subprocess.check_output(lsof_pipeline, shell=True)
        # Empty output means no listener was found: server dead.
        return int(pid_text) if pid_text else None
    except Exception:
        return None


def colored(text, args, color=None, on_color=None, attrs=None):
    """Colorize ``text`` with termcolor when that is possible and wanted.

    Coloring is applied only if the termcolor module was imported and either
    stdout is a terminal or ``args.force_color`` is set; otherwise ``text`` is
    returned unchanged.
    """
    use_color = termcolor and (sys.stdout.isatty() or args.force_color)
    if not use_color:
        return text
    return termcolor.colored(text, color, on_color, attrs)


# Set to True by run_tests_array when a failing test's stderr looks like a lost
# server connection (only with --stop); stops further tests and suites.
SERVER_DIED = False
# Exit status passed to sys.exit() at the end of main(); set to 1 on test
# failures or when the hung check finds queries in the processlist.
exit_code = 0


def run_tests_array(all_tests_with_params):
    """Run a batch of test cases of one suite and print a summary.

    ``all_tests_with_params`` is one sequence
    ``(all_tests, suite, suite_dir, suite_tmp_dir, run_total)``; the arguments
    are packed so that the function takes a single parameter (main() hands it
    to ``multiprocessing.Pool.map``).

    * ``all_tests`` - file names of the cases, relative to ``suite_dir``
    * ``suite`` - suite name used in messages
    * ``suite_dir`` - directory with cases, ``.reference`` and ``.disabled`` files
    * ``suite_tmp_dir`` - directory for ``.stdout`` / ``.stderr`` files
    * ``run_total`` - number of slices the suite was split into

    Reads the module-level ``args`` and ``server_logs_level``.  Sets the
    globals ``SERVER_DIED`` and ``exit_code``.

    NOTE(review): when this runs inside a pool worker the assignments to
    ``SERVER_DIED`` / ``exit_code`` happen in the worker process and
    presumably do not reach the parent - confirm and propagate via the return
    value if needed.
    """
    global exit_code
    global SERVER_DIED

    all_tests, suite, suite_dir, suite_tmp_dir, run_total = all_tests_with_params

    OP_SQUARE_BRACKET = colored("[", args, attrs=['bold'])
    CL_SQUARE_BRACKET = colored("]", args, attrs=['bold'])

    MSG_FAIL = OP_SQUARE_BRACKET + colored(" FAIL ", args, "red", attrs=['bold']) + CL_SQUARE_BRACKET
    MSG_UNKNOWN = OP_SQUARE_BRACKET + colored(" UNKNOWN ", args, "yellow", attrs=['bold']) + CL_SQUARE_BRACKET
    MSG_OK = OP_SQUARE_BRACKET + colored(" OK ", args, "green", attrs=['bold']) + CL_SQUARE_BRACKET
    MSG_SKIPPED = OP_SQUARE_BRACKET + colored(" SKIPPED ", args, "cyan", attrs=['bold']) + CL_SQUARE_BRACKET

    passed_total = 0
    skipped_total = 0
    failures_total = 0
    failures = 0
    # Number of consecutive failures; the batch is aborted when it reaches 20.
    failures_chain = 0

    client_options = get_additional_client_options(args)

    def print_test_time(test_time):
        if args.print_time:
            print(" {0:.2f} sec.".format(test_time), end='')

    if len(all_tests):
        print("\nRunning {} {} tests.".format(len(all_tests), suite) + "\n")

    for case in all_tests:
        if SERVER_DIED:
            break

        case_file = os.path.join(suite_dir, case)
        (name, ext) = os.path.splitext(case)

        try:
            sys.stdout.write("{0:72}".format(name + ": "))
            if run_total == 1:
                sys.stdout.flush()

            if args.skip and any(s in name for s in args.skip):
                print(MSG_SKIPPED + " - skip")
                skipped_total += 1
            elif not args.zookeeper and ('zookeeper' in name
                    or 'replica' in name):
                print(MSG_SKIPPED + " - no zookeeper")
                skipped_total += 1
            elif not args.shard and ('shard' in name
                    or 'distributed' in name
                    or 'global' in name):
                print(MSG_SKIPPED + " - no shard")
                skipped_total += 1
            # args.no_long is False only when --no-long was passed
            # (the option is declared with action='store_false').
            elif not args.no_long and ('long' in name
                    # Tests for races and deadlocks are usually run in a loop
                    #  for a significant amount of time
                    or 'deadlock' in name
                    or 'race' in name):
                print(MSG_SKIPPED + " - no long")
                skipped_total += 1
            else:
                disabled_file = os.path.join(suite_dir, name) + '.disabled'

                if os.path.exists(disabled_file) and not args.disabled:
                    # The content of the .disabled file is the skip reason.
                    with open(disabled_file, 'r') as disabled:
                        message = disabled.read()
                    print(MSG_SKIPPED + " - " + message)
                    # Count it like every other skipped test so the summary
                    # matches the printed SKIPPED lines.
                    skipped_total += 1
                else:

                    if args.testname:
                        # Send a marker query naming the test before running it.
                        clickhouse_proc = Popen(shlex.split(args.client_with_database), stdin=PIPE, stdout=PIPE, stderr=PIPE)
                        clickhouse_proc.communicate("SELECT 'Running test {suite}/{case} from pid={pid}';".format(pid = os.getpid(), case = case, suite = suite))

                    reference_file = os.path.join(suite_dir, name) + '.reference'
                    stdout_file = os.path.join(suite_tmp_dir, name) + '.stdout'
                    stderr_file = os.path.join(suite_tmp_dir, name) + '.stderr'

                    proc, stdout, stderr, total_time = run_single_test(args, ext, server_logs_level, client_options, case_file, stdout_file, stderr_file)
                    if proc.returncode is None:
                        # Still running after the timeout: kill it.
                        try:
                            proc.kill()
                        except OSError as e:
                            # ESRCH: the process finished in the meantime.
                            if e.errno != ESRCH:
                                raise

                        failures += 1
                        print(MSG_FAIL, end='')
                        print_test_time(total_time)
                        print(" - Timeout!")
                    else:
                        # Re-run on retryable errors, sleeping 2, 4, ... 64
                        # seconds after each attempt (at most 6 re-runs).
                        counter = 1
                        while proc.returncode != 0 and need_retry(stderr):
                            proc, stdout, stderr, total_time = run_single_test(args, ext, server_logs_level, client_options, case_file, stdout_file, stderr_file)
                            sleep(2**counter)
                            counter += 1
                            if counter > 6:
                                break

                        if proc.returncode != 0:
                            failures += 1
                            failures_chain += 1
                            print(MSG_FAIL, end='')
                            print_test_time(total_time)
                            print(" - return code {}".format(proc.returncode))

                            if stderr:
                                print(stderr.encode('utf-8'))

                            if args.stop and ('Connection refused' in stderr or 'Attempt to read after eof' in stderr) and 'Received exception from server' not in stderr:
                                SERVER_DIED = True

                        elif stderr:
                            failures += 1
                            failures_chain += 1
                            print(MSG_FAIL, end='')
                            print_test_time(total_time)
                            print(" - having stderror:\n{}".format(stderr.encode('utf-8')))
                        elif 'Exception' in stdout:
                            failures += 1
                            failures_chain += 1
                            print(MSG_FAIL, end='')
                            print_test_time(total_time)
                            print(" - having exception:\n{}".format(stdout.encode('utf-8')))
                        elif not os.path.isfile(reference_file):
                            print(MSG_UNKNOWN, end='')
                            print_test_time(total_time)
                            print(" - no reference file")
                        else:
                            # diff -q exits non-zero when the files differ.
                            result_is_different = subprocess.call(['diff', '-q', reference_file, stdout_file], stdout = PIPE)

                            if result_is_different:
                                diff = Popen(['diff', '-U', str(args.unified), reference_file, stdout_file], stdout = PIPE).communicate()[0]
                                failures += 1
                                print(MSG_FAIL, end='')
                                print_test_time(total_time)
                                print(" - result differs with reference:\n{}".format(diff))
                            else:
                                passed_total += 1
                                failures_chain = 0
                                print(MSG_OK, end='')
                                print_test_time(total_time)
                                print()
                                # Output files are kept only for failed tests.
                                if os.path.exists(stdout_file):
                                    os.remove(stdout_file)
                                if os.path.exists(stderr_file):
                                    os.remove(stderr_file)
        except KeyboardInterrupt:
            print(colored("Break tests execution", args, "red"))
            # Bare raise keeps the original traceback.
            raise
        except:
            # Any other error in the runner itself is reported as a failure of
            # this test and the loop continues with the next case.
            import traceback
            exc_type, exc_value, tb = sys.exc_info()
            failures += 1
            print("{0} - Test internal error: {1}\n{2}\n{3}".format(MSG_FAIL, exc_type.__name__, exc_value, "\n".join(traceback.format_tb(tb, 10))))

        if failures_chain >= 20:
            break

    failures_total = failures_total + failures

    if failures_total > 0:
        print(colored("\nHaving {failures_total} errors! {passed_total} tests passed. {skipped_total} tests skipped.".format(passed_total = passed_total, skipped_total = skipped_total, failures_total = failures_total), args, "red", attrs=["bold"]))
        exit_code = 1
    else:
        print(colored("\n{passed_total} tests passed. {skipped_total} tests skipped.".format(passed_total = passed_total, skipped_total = skipped_total), args, "green", attrs=["bold"]))


# Log level passed to the client as --send_logs_level for .sql tests and used
# as the default for the CLICKHOUSE_CLIENT_SERVER_LOGS_LEVEL environment variable.
server_logs_level = "warning"


def main(args):
    """Discover the test suites under ``args.queries``, run them and exit.

    Steps: export default CLICKHOUSE_* environment variables, auto-detect
    zookeeper/shard availability when not given on the command line, create
    the test databases, run every suite directory named ``<number>_<name>``,
    optionally check for hung queries, and finally call ``sys.exit``.

    The function never returns normally: it exits with status 1 when no test
    was run, otherwise with the global ``exit_code``.
    """
    global SERVER_DIED
    global exit_code
    global server_logs_level

    def is_data_present():
        """Return True if the client's answer to ``EXISTS TABLE test.hits`` starts with '1'.

        Raises CalledProcessError when the client exits with a non-zero code.
        """
        clickhouse_proc = Popen(shlex.split(args.client), stdin=PIPE, stdout=PIPE, stderr=PIPE)
        (stdout, stderr) = clickhouse_proc.communicate("EXISTS TABLE test.hits")
        if clickhouse_proc.returncode != 0:
            raise CalledProcessError(clickhouse_proc.returncode, args.client, stderr)

        return stdout.startswith('1')

    base_dir = os.path.abspath(args.queries)
    tmp_dir = os.path.abspath(args.tmp)

    # Keep same default values as in queries/shell_config.sh
    # (setdefault: values already present in the environment win).
    os.environ.setdefault("CLICKHOUSE_BINARY", args.binary)
    #os.environ.setdefault("CLICKHOUSE_CLIENT", args.client)
    os.environ.setdefault("CLICKHOUSE_CONFIG", args.configserver)
    if args.configclient:
        os.environ.setdefault("CLICKHOUSE_CONFIG_CLIENT", args.configclient)
    os.environ.setdefault("CLICKHOUSE_TMP", tmp_dir)
    os.environ.setdefault("CLICKHOUSE_DATABASE", args.database)

    # Force to print server warnings in stderr
    # Shell scripts could change logging level
    os.environ.setdefault("CLICKHOUSE_CLIENT_SERVER_LOGS_LEVEL", server_logs_level)

    # Neither --zookeeper nor --no-zookeeper given: enable zookeeper tests when
    # extract-from-config prints at least one non-empty line for the key.
    if args.zookeeper is None:
        code, out = commands.getstatusoutput(args.extract_from_config +" --try --config " + args.configserver + ' --key zookeeper | grep . | wc -l')
        try:
            if int(out) > 0:
                args.zookeeper = True
            else:
                args.zookeeper = False
        except ValueError:
            # Output was not a number (e.g. an error message): assume absent.
            args.zookeeper = False

    # Neither --shard nor --no-shard given: enable shard tests when the
    # configured listen_host values contain 127.0.0.2 or "::".
    if args.shard is None:
        code, out = commands.getstatusoutput(args.extract_from_config + " --try --config " + args.configserver + ' --key listen_host | grep -E "127.0.0.2|::"')
        if out:
            args.shard = True
        else:
            args.shard = False

    # Create the database for this run and the "test" database.  The result of
    # these client calls is not checked.
    clickhouse_proc_create = Popen(shlex.split(args.client), stdin=PIPE, stdout=PIPE, stderr=PIPE)
    clickhouse_proc_create.communicate("CREATE DATABASE IF NOT EXISTS " + args.database)
    if args.database != "test":
        clickhouse_proc_create = Popen(shlex.split(args.client), stdin=PIPE, stdout=PIPE, stderr=PIPE)
        clickhouse_proc_create.communicate("CREATE DATABASE IF NOT EXISTS test")

    def is_test_from_dir(suite_dir, case):
        """Return True if ``case`` is a regular .sql, .sh or .py file in ``suite_dir``."""
        case_file = os.path.join(suite_dir, case)
        (name, ext) = os.path.splitext(case)
        return os.path.isfile(case_file) and (ext == '.sql' or ext == '.sh' or ext == '.py')

    def sute_key_func(item):
       """Sort key for suite directory names.

       Random order when requested; otherwise ``(numeric prefix, rest)`` for
       names like ``0_stateless``, 99998 for names without '_' and 99997 for
       names whose prefix is not a number.
       """
       if args.order == 'random':
             return random()

       if -1 == item.find('_'):
           return 99998

       prefix, suffix = item.split('_', 1)

       try:
           return int(prefix), suffix
       except ValueError:
           return 99997

    total_tests_run = 0
    for suite in sorted(os.listdir(base_dir), key=sute_key_func):
        if SERVER_DIED:
            break

        suite_dir = os.path.join(base_dir, suite)
        suite_re_obj = re.search('^[0-9]+_(.*)$', suite)
        if not suite_re_obj: #skip .gitignore and so on
            continue

        suite_tmp_dir = os.path.join(tmp_dir, suite)
        if not os.path.exists(suite_tmp_dir):
            os.makedirs(suite_tmp_dir)

        # From here on "suite" is the name without the numeric prefix.
        suite = suite_re_obj.group(1)
        if os.path.isdir(suite_dir):

            if 'stateful' in suite and not args.no_stateful and not is_data_present():
                print("Won't run stateful tests because test data wasn't loaded.")
                continue
            if 'stateless' in suite and args.no_stateless:
                print("Won't run stateless tests because they were manually disabled.")
                continue
            if 'stateful' in suite and args.no_stateful:
                print("Won't run stateful tests because they were manually disabled.")
                continue

            # Reverse sort order: we want run newest test first.
            # And not reverse subtests
            def key_func(item):
                if args.order == 'random':
                    return random()

                reverse = 1 if args.order == 'asc' else -1

                if -1 == item.find('_'):
                    return 99998

                prefix, suffix = item.split('_', 1)

                try:
                    # Only the numeric prefix is negated for descending order,
                    # so cases sharing a prefix keep ascending suffix order.
                    return reverse * int(prefix), suffix
                except ValueError:
                    return 99997

            all_tests = os.listdir(suite_dir)
            # Relies on Python 2 filter() returning a list (sorted in place below).
            all_tests = filter(lambda case: is_test_from_dir(suite_dir, case), all_tests)
            if args.test:
                # Keep only the cases matching at least one regex from the command line.
                all_tests = [t for t in all_tests if any([re.search(r, t) for r in args.test])]
            all_tests.sort(key=key_func)

            # --parallel N/TOTAL: this invocation handles slice N out of TOTAL.
            # Floats are used so the slice boundaries below use true division.
            run_n, run_total = args.parallel.split('/')
            run_n = float(run_n)
            run_total = float(run_total)
            tests_n = len(all_tests)
            if run_total > tests_n:
                run_total = tests_n
            if run_n > run_total:
                # Fewer tests than slices: nothing for this slice in this suite.
                continue

            jobs = args.jobs
            if jobs > tests_n:
                jobs = tests_n
            if jobs > run_total:
                # With more local jobs than slices, split into one slice per job.
                run_total = jobs

            # Cut the sorted list into run_total consecutive slices.
            all_tests_array = []
            for n in range(1, 1 + int(run_total)):
                start = int(tests_n / run_total * (n - 1))
                end = int(tests_n / run_total * n)
                all_tests_array.append([all_tests[start : end], suite, suite_dir, suite_tmp_dir, run_total])

            if jobs > 1:
                # All slices are run, each in a pool worker.
                # NOTE(review): run_tests_array reports failures through the
                # globals exit_code / SERVER_DIED; assignments made inside
                # worker processes presumably never reach this process, so
                # failures may not affect the exit status here - confirm.
                with closing(multiprocessing.Pool(processes=jobs)) as pool:
                    pool.map(run_tests_array, all_tests_array)
                    pool.terminate()
            else:
                # Single job: run only the slice selected by --parallel.
                run_tests_array(all_tests_array[int(run_n)-1])

            # Counts every test of the suite, not only the slice that was run.
            total_tests_run += tests_n

    if args.hung_check:
        # Any non-empty processlist output is treated as hung queries.
        processlist = get_processlist(args.client_with_database)
        if processlist:
            print(colored("\nFound hung queries in processlist:", args, "red", attrs=["bold"]))
            print(processlist)

            clickhouse_tcp_port = os.getenv("CLICKHOUSE_PORT_TCP", '9000')
            server_pid = get_server_pid(clickhouse_tcp_port)
            if server_pid:
                print("\nLocated ClickHouse server process {} listening at TCP port {}".format(server_pid, clickhouse_tcp_port))
                # NOTE(review): the message says "system.stacktraces" while
                # get_stacktraces_from_clickhouse queries system.stack_trace.
                print("\nCollecting stacktraces from system.stacktraces table:")
                print(get_stacktraces_from_clickhouse(args.client))
                print("\nCollecting stacktraces from all running threads with gdb:")
                print(get_stacktraces_from_gdb(server_pid))
            else:
                print(
                    colored(
                        "\nUnable to locate ClickHouse server process listening at TCP port {}. "
                        "It must have crashed or exited prematurely!".format(clickhouse_tcp_port),
                        args, "red", attrs=["bold"]))

            exit_code = 1
        else:
            print(colored("\nNo queries hung.", args, "green", attrs=["bold"]))

    if total_tests_run == 0:
        print("No tests were run.")
        sys.exit(1)

    sys.exit(exit_code)


def find_binary(name):
    """Return True if ``name`` can be located as an executable.

    ``name`` is accepted when it exists as given and is executable, or when an
    executable entry with that name is found in one of the PATH directories,
    in ``/usr/local/bin`` or in ``/usr/bin``.
    """
    if os.path.exists(name) and os.access(name, os.X_OK):
        return True

    # PATH may be unset; treat that as an empty search path instead of
    # failing on None.
    search_dirs = os.environ.get("PATH", "").split(os.pathsep)
    # maybe it wasn't in PATH
    search_dirs += ['/usr/local/bin', '/usr/bin']

    return any(os.access(os.path.join(directory, name), os.X_OK) for directory in search_dirs)


def get_additional_client_options(args):
    """Format ``args.client_option`` as command-line options.

    Every entry is prefixed with ``--`` and the entries are joined with
    spaces; an empty string is returned when there are no entries.
    """
    extra_options = args.client_option
    if not extra_options:
        return ''
    return ' '.join(['--' + option for option in extra_options])


def get_additional_client_options_url(args):
    """Format ``args.client_option`` as URL parameters.

    The entries are joined with ``&``; an empty string is returned when there
    are no entries.
    """
    url_parameters = args.client_option
    return '&'.join(url_parameters) if url_parameters else ''


# Script entry point: parse the command line, fill in defaults that depend on
# the environment, then hand over to main().  ``args`` is deliberately a
# module-level name: run_tests_array reads it as a global.
if __name__ == '__main__':
    parser=ArgumentParser(description='ClickHouse functional tests')
    parser.add_argument('-q', '--queries', help='Path to queries dir')
    parser.add_argument('--tmp', help='Path to tmp dir')
    parser.add_argument('-b', '--binary', default='clickhouse', help='Path to clickhouse binary or name of binary in PATH')
    parser.add_argument('-c', '--client', help='Client program')
    parser.add_argument('--extract_from_config', help='extract-from-config program')
    parser.add_argument('--configclient', help='Client config (if you use not default ports)')
    parser.add_argument('--configserver', default= '/etc/clickhouse-server/config.xml', help='Preprocessed server config')
    parser.add_argument('-o', '--output', help='Output xUnit compliant test report directory')
    parser.add_argument('-t', '--timeout', type=int, default=600, help='Timeout for each test case in seconds')
    parser.add_argument('test', nargs='*', help='Optional test case name regex')
    parser.add_argument('-d', '--disabled', action='store_true', default=False, help='Also run disabled tests')
    parser.add_argument('--stop', action='store_true', default=None, dest='stop', help='Stop on network errors')
    parser.add_argument('--order', default='desc', choices=['asc', 'desc', 'random'], help='Run order')
    parser.add_argument('--testname', action='store_true', default=None, dest='testname', help='Make query with test name before test run')
    parser.add_argument('--hung-check', action='store_true', default=False)
    parser.add_argument('--force-color', action='store_true', default=False)
    parser.add_argument('--database', help='Database for tests (random name test_XXXXXX by default)')
    parser.add_argument('--parallel', default='1/1', help='One parallel test run number/total')
    # A bare "-j" (no value) yields None, which is replaced by the CPU count below.
    parser.add_argument('-j', '--jobs', default=1, nargs='?', type=int, help='Run all tests in parallel')
    parser.add_argument('-U', '--unified', default=3, type=int, help='output NUM lines of unified context')

    parser.add_argument('--no-stateless', action='store_true', help='Disable all stateless tests')
    parser.add_argument('--no-stateful', action='store_true', help='Disable all stateful tests')
    parser.add_argument('--skip', nargs='+', help="Skip these tests")
    # store_false: args.no_long is True by default and becomes False when
    # --no-long is passed; run_tests_array skips long tests when it is False.
    parser.add_argument('--no-long', action='store_false', dest='no_long', help='Do not run long tests')
    parser.add_argument('--client-option', nargs='+', help='Specify additional client argument')
    parser.add_argument('--print-time', action='store_true', dest='print_time', help='Print test time')
    # Both options of each pair share one dest; None means "auto-detect in main()".
    group=parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--zookeeper', action='store_true', default=None, dest='zookeeper', help='Run zookeeper related tests')
    group.add_argument('--no-zookeeper', action='store_false', default=None, dest='zookeeper', help='Do not run zookeeper related tests')
    group=parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--shard', action='store_true', default=None, dest='shard', help='Run sharding related tests (required to clickhouse-server listen 127.0.0.2 127.0.0.3)')
    group.add_argument('--no-shard', action='store_false', default=None, dest='shard', help='Do not run shard related tests')

    args = parser.parse_args()

    # Locate the queries directory: ./queries first, then the install locations.
    if args.queries is None and os.path.isdir('queries'):
        args.queries = 'queries'
    elif args.queries is None:
        if (os.path.isdir('/usr/local/share/clickhouse-test/queries')):
            args.queries = '/usr/local/share/clickhouse-test/queries'
        if (args.queries is None and os.path.isdir('/usr/share/clickhouse-test/queries')):
            args.queries = '/usr/share/clickhouse-test/queries'
        # On this path (no ./queries directory) temporary files default to /tmp.
        if args.tmp is None:
            args.tmp = '/tmp/clickhouse-test'
    if args.queries is None:
        print("Failed to detect path to the queries directory. Please specify it with '--queries' option.", file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use and is not guaranteed to exist.
        sys.exit(1)
    if args.tmp is None:
        args.tmp = args.queries
    if args.client is None:
        # Prefer "<binary>-client"; fall back to "<binary> client".
        if find_binary(args.binary + '-client'):
            args.client = args.binary + '-client'
        elif find_binary(args.binary):
            args.client = args.binary + ' client'
        else:
            print("No 'clickhouse' binary found in PATH", file=sys.stderr)
            parser.print_help()
            sys.exit(1)

        # Connection settings are appended only to an auto-detected client.
        if args.configclient:
            args.client += ' --config-file=' + args.configclient
        if os.getenv("CLICKHOUSE_HOST"):
            args.client += ' --host=' + os.getenv("CLICKHOUSE_HOST")
        if os.getenv("CLICKHOUSE_PORT_TCP"):
            args.client += ' --port=' + os.getenv("CLICKHOUSE_PORT_TCP")
        if os.getenv("CLICKHOUSE_DATABASE"):
            args.client += ' --database=' + os.getenv("CLICKHOUSE_DATABASE")

    if args.client_option:
        # Set options for client
        if 'CLICKHOUSE_CLIENT_OPT' in os.environ:
            os.environ['CLICKHOUSE_CLIENT_OPT'] += ' '
        else:
            os.environ['CLICKHOUSE_CLIENT_OPT'] = ''

        os.environ['CLICKHOUSE_CLIENT_OPT'] += get_additional_client_options(args)

        # Set options for curl
        if 'CLICKHOUSE_URL_PARAMS' in os.environ:
            os.environ['CLICKHOUSE_URL_PARAMS'] += '&'
        else:
            os.environ['CLICKHOUSE_URL_PARAMS'] = ''

        os.environ['CLICKHOUSE_URL_PARAMS'] += get_additional_client_options_url(args)


    # args.client stays database-agnostic; args.client_with_database selects
    # the database used by this run.
    args.client_with_database = args.client
    if not args.database:
        def random_str(length=6):
            # Local imports: the module-level name "random" is the function
            # random.random, not the module.
            import random
            import string
            alphabet = string.ascii_lowercase + string.digits
            return ''.join(random.choice(alphabet) for _ in range(length))
        args.database = 'test_{suffix}'.format(suffix=random_str())
    args.client_with_database += ' --database=' + args.database

    if args.extract_from_config is None:
        if os.access(args.binary + '-extract-from-config', os.X_OK):
            args.extract_from_config = args.binary + '-extract-from-config'
        else:
            args.extract_from_config = args.binary + ' extract-from-config'

    if args.jobs is None:
        args.jobs = multiprocessing.cpu_count()

    main(args)
